library(tidyverse, verbose = FALSE)
library(tidymodels, verbose = FALSE)
library(reticulate)
library(ggplot2)
library(plotly)
library(RColorBrewer)
library(bslib)
library(Metrics)
reticulate::use_virtualenv("r-tf")
Simone Brazzi
August 12, 2024
At prediction time, the model must return a vector containing a 1 or a 0 for each label in the dataset (toxic, severe_toxic, obscene, threat, insult, identity_hate). This way, a harmless comment is classified by a vector of all zeros [0,0,0,0,0,0]. Conversely, a harmful comment will have at least one 1 among the 6 labels.
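As a minimal sketch of this encoding (the arrays below are hypothetical examples, not rows from the dataset):

```python
import numpy as np

# One row per comment, one column per label, in the order
# (toxic, severe_toxic, obscene, threat, insult, identity_hate).
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

clean_comment = np.array([0, 0, 0, 0, 0, 0])    # harmless: all zeros
harmful_comment = np.array([1, 0, 1, 0, 1, 0])  # toxic + obscene + insult

def is_harmful(y):
    # a comment is harmful if at least one of the 6 labels is 1
    return bool(y.any())

print(is_harmful(clean_comment))    # False
print(is_harmful(harmful_comment))  # True
```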
Leveraging Quarto and RStudio, I will set up an R and Python environment.
Import the R libraries. These will be used both for rendering the document and for data analysis: I prefer ggplot2 over matplotlib. I will also use colorblind-safe palettes.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras_nlp
from keras.backend import clear_session
from keras.models import Model, load_model
from keras.layers import TextVectorization, Input, Dense, Embedding, Dropout, GlobalAveragePooling1D, LSTM, Bidirectional, GlobalMaxPool1D, Flatten, Attention, LayerNormalization
from keras.metrics import Precision, Recall, AUC, SensitivityAtSpecificity, SpecificityAtSensitivity, F1Score
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import multilabel_confusion_matrix, classification_report, ConfusionMatrixDisplay, precision_recall_curve, f1_score, recall_score, roc_auc_score
Create a Config class to store all the useful parameters for the model and for the project.
I created a class with all the basic configuration of the model, to improve the readability.
class Config():
    def __init__(self):
        self.url = "https://s3.eu-west-3.amazonaws.com/profession.ai/datasets/Filter_Toxic_Comments_dataset.csv"
        self.max_tokens = 20000
        self.output_sequence_length = 911  # check the analysis done to establish this value
        self.embedding_dim = 128
        self.batch_size = 32
        self.epochs = 100
        self.temp_split = 0.3
        self.test_split = 0.5
        self.random_state = 42
        self.total_samples = 159571  # total train samples
        self.train_samples = 111699
        self.val_samples = 23936
        self.features = 'comment_text'
        self.labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        self.new_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', "clean"]
        self.label_mapping = {label: i for i, label in enumerate(self.labels)}
        self.new_label_mapping = {label: i for i, label in enumerate(self.new_labels)}
        self.path = "/Users/simonebrazzi/R/blog/posts/toxic_comment_filter/history/f1score/"
        self.model = self.path + "model_f1.keras"
        self.checkpoint = self.path + "checkpoint.lstm_model_f1.keras"
        self.history = self.path + "lstm_model_f1.xlsx"
        self.metrics = [
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc', multi_label=True, num_labels=len(self.labels)),
            F1Score(name="f1", average="macro")
        ]
    def get_early_stopping(self):
        early_stopping = keras.callbacks.EarlyStopping(
            monitor="val_f1",  # "val_recall"
            min_delta=0.2,
            patience=10,
            verbose=0,
            mode="max",
            restore_best_weights=True,
            start_from_epoch=3
        )
        return early_stopping

    def get_model_checkpoint(self, filepath):
        model_checkpoint = keras.callbacks.ModelCheckpoint(
            filepath=filepath,
            monitor="val_f1",  # "val_recall"
            verbose=0,
            save_best_only=True,
            save_weights_only=False,
            mode="max",
            save_freq="epoch"
        )
        return model_checkpoint
    def find_optimal_threshold_cv(self, ytrue, yproba, metric, thresholds=np.arange(.05, .35, .05), n_splits=7):
        # instantiate KFold
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        threshold_scores = []
        for threshold in thresholds:
            cv_scores = []
            for train_index, val_index in kf.split(ytrue):
                ytrue_val = ytrue[val_index]
                yproba_val = yproba[val_index]
                ypred_val = (yproba_val >= threshold).astype(int)
                score = metric(ytrue_val, ypred_val, average="macro")
                cv_scores.append(score)
            mean_score = np.mean(cv_scores)
            threshold_scores.append((threshold, mean_score))
        # find the threshold with the highest mean score
        best_threshold, best_score = max(threshold_scores, key=lambda x: x[1])
        return best_threshold, best_score
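To see how this threshold search behaves, here is a standalone re-sketch of the same loop on synthetic data (the labels and probabilities are fabricated, and the function is restated so the snippet runs without the Keras-backed Config class):

```python
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score

# Standalone version of the same cross-validated threshold search.
def find_optimal_threshold_cv(ytrue, yproba, metric, thresholds=np.arange(.05, .35, .05), n_splits=7):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    threshold_scores = []
    for threshold in thresholds:
        cv_scores = []
        for _, val_index in kf.split(ytrue):
            ypred_val = (yproba[val_index] >= threshold).astype(int)
            cv_scores.append(metric(ytrue[val_index], ypred_val, average="macro"))
        threshold_scores.append((threshold, np.mean(cv_scores)))
    return max(threshold_scores, key=lambda x: x[1])

# Fake multi-label ground truth and noisy "predicted" probabilities.
rng = np.random.default_rng(0)
ytrue = rng.integers(0, 2, size=(500, 6))
yproba = np.clip(ytrue * 0.6 + rng.random((500, 6)) * 0.4, 0, 1)

best_threshold, best_score = find_optimal_threshold_cv(ytrue, yproba, f1_score)
print(best_threshold, best_score)
```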
config = Config()
The dataset is accessible using tf.keras.utils.get_file to get the file from the URL. N.B.: for reproducibility purposes, I also downloaded the dataset, since there have been times when the link was not available.
# A tibble: 5 × 8
comment_text toxic severe_toxic obscene threat insult identity_hate
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 "Explanation\nWhy the … 0 0 0 0 0 0
2 "D'aww! He matches thi… 0 0 0 0 0 0
3 "Hey man, I'm really n… 0 0 0 0 0 0
4 "\"\nMore\nI can't mak… 0 0 0 0 0 0
5 "You, sir, are my hero… 0 0 0 0 0 0
# ℹ 1 more variable: sum_injurious <dbl>
Let's create a clean variable for EDA purposes: I want to visually see how many observations are clean versus the other labels.
First, a check on the dataset to find possible missing values and imbalances.
library(reticulate)
df_r <- py$df
new_labels_r <- py$config$new_labels
df_r_grouped <- df_r %>%
select(all_of(new_labels_r)) %>%
pivot_longer(
cols = all_of(new_labels_r),
names_to = "label",
values_to = "value"
) %>%
group_by(label) %>%
summarise(count = sum(value)) %>%
mutate(freq = round(count / sum(count), 4))
df_r_grouped
# A tibble: 7 × 3
label count freq
<chr> <dbl> <dbl>
1 clean 143346 0.803
2 identity_hate 1405 0.0079
3 insult 7877 0.0441
4 obscene 8449 0.0473
5 severe_toxic 1595 0.0089
6 threat 478 0.0027
7 toxic 15294 0.0857
library(reticulate)
barchart <- df_r_grouped %>%
ggplot(aes(x = reorder(label, count), y = count, fill = label)) +
geom_col() +
labs(
x = "Labels",
y = "Count"
) +
# sort bars in descending order
scale_x_discrete(limits = df_r_grouped$label[order(df_r_grouped$count, decreasing = TRUE)]) +
scale_fill_brewer(type = "seq", palette = "RdYlBu") +
theme_minimal()
ggplotly(barchart)
It is visible how imbalanced the dataset is. This means it could be useful to compute the class weights and use this argument during the training.
It is clear that most of our texts are clean: 80.33% of the observations are clean, while only 19.67% are toxic comments.
To convert the text into a useful input for a NN, it is necessary to use a TextVectorization layer. See Section 4.
One of its arguments is output_sequence_length: to better define it, it is useful to analyze the text lengths. To simulate what the model will do, we are going to remove the punctuation and the newlines from the comments.
# A tibble: 1 × 6
Min. `1st Qu.` Median Mean `3rd Qu.` Max.
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 4 91 196 378. 419 5000
library(reticulate)
boxplot <- df_r %>%
mutate(
comment_text_clean = comment_text %>%
tolower() %>%
str_remove_all("[[:punct:]]") %>%
str_replace_all("\n", " "),
text_length = comment_text_clean %>% str_count()
) %>%
# pull(text_length) %>%
ggplot(aes(y = text_length)) +
geom_boxplot() +
coord_flip() +
theme_minimal()
ggplotly(boxplot)
library(reticulate)
df_ <- df_r %>%
mutate(
comment_text_clean = comment_text %>%
tolower() %>%
str_remove_all("[[:punct:]]") %>%
str_replace_all("\n", " "),
text_length = comment_text_clean %>% str_count()
)
Q1 <- quantile(df_$text_length, 0.25)
Q3 <- quantile(df_$text_length, 0.75)
IQR <- Q3 - Q1
upper_fence <- as.integer(Q3 + 1.5 * IQR)
histogram <- df_ %>%
ggplot(aes(x = text_length)) +
geom_histogram(bins = 50) +
geom_vline(aes(xintercept = upper_fence), color = "red", linetype = "dashed", linewidth = 1) +
theme_minimal() +
xlab("Text Length") +
ylab("Frequency") +
xlim(0, max(df_$text_length, upper_fence))
ggplotly(histogram)
Considering all the above analysis, I think a good starting value for the output_sequence_length is 911, the upper fence of the boxplot (the dashed red vertical line in the last plot). Doing so, we remove the outliers, which are a small part of our dataset.
Now we can split the dataset into 3 sets: train, test and validation. Since sklearn does not provide a function to split into these 3 sets directly, we can do the following:
- split into a train set and a temporary set with a 0.3 split;
- split the temporary set into 2 equally sized test and validation sets.
x = df[config.features].values
y = df[config.labels].values
xtrain, xtemp, ytrain, ytemp = train_test_split(
x,
y,
test_size=config.temp_split, # .3
random_state=config.random_state
)
xtest, xval, ytest, yval = train_test_split(
xtemp,
ytemp,
test_size=config.test_split, # .5
random_state=config.random_state
)
xtrain shape: py$xtrain.shape
ytrain shape: py$ytrain.shape
xtest shape: py$xtest.shape
ytest shape: py$ytest.shape
xval shape: py$xval.shape
yval shape: py$yval.shape
The datasets are created using the tf.data.Dataset function. It creates a data input pipeline. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations. The tf.data.Dataset is an abstraction that represents a sequence of elements, in which each element consists of one or more components. Here each dataset is created using from_tensor_slices, which creates a tf.data.Dataset from a tuple (features, labels). .batch lets us work in batches to improve performance, while .prefetch overlaps the preprocessing and model execution of a training step: while the model is executing training step s, the input pipeline is reading the data for step s+1. Check the documentation for further information.
train_ds = (
tf.data.Dataset
.from_tensor_slices((xtrain, ytrain))
.shuffle(xtrain.shape[0])
.batch(config.batch_size)
.prefetch(tf.data.experimental.AUTOTUNE)
)
test_ds = (
tf.data.Dataset
.from_tensor_slices((xtest, ytest))
.batch(config.batch_size)
.prefetch(tf.data.experimental.AUTOTUNE)
)
val_ds = (
tf.data.Dataset
.from_tensor_slices((xval, yval))
.batch(config.batch_size)
.prefetch(tf.data.experimental.AUTOTUNE)
)
train_ds cardinality: 3491
val_ds cardinality: 748
test_ds cardinality: 748
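These cardinalities can be sanity-checked from the split sizes in Config; a small arithmetic sketch (assuming sklearn's usual rounding-up of the test split):

```python
import math

# Reconstruct the split sizes from the totals stored in Config.
total = 159571
temp = math.ceil(total * 0.3)  # sklearn rounds the test split up: 47872
train = total - temp           # 111699
test = temp // 2               # 23936
val = temp - test              # 23936

# Number of batches = ceil(samples / batch_size), as reported by cardinality.
batch_size = 32
print(math.ceil(train / batch_size))  # 3491
print(math.ceil(val / batch_size))    # 748
print(math.ceil(test / batch_size))   # 748
```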
Check the first element of the dataset to be sure that the preprocessing is done correctly.
(array([b'I will force you to eat dog poop. \n\nIt will be funny watching you gag and vomit with crap smeared all over your face.',
b", I'm glad I'm not the only one looking at this; I saw the latest series of edits this morning and they needed reverting too. The lack of communication doesn't help, of course. Thanks also for picking up the baton on the article and user talk pages, and bringing up the name issue.",
b'Neologism \nIt certainly qualifies as a neologism, but I dont have time or interest to research who actually coined it. That needs to be added. -|t 4 July 2005 19:34 (UTC)',
b'Please stop. If you continue to vandalize Wikipedia, as you did to Cheetos, you will be blocked from editing. 83',
b'Image:MBizLogo.jpg listed for deletion \nAn image or media file that you uploaded or altered, Image:MBizLogo.jpg, has been listed at Wikipedia:Images and media for deletion. Please look there to see why this is (you may have to search for the title of the image to find its entry), if you are interested in it not being deleted. Thank you. talk',
b'the attention of American film makers.',
b'"\n\n ==The SP\xc3\x96 in Mein Kampf \n\nAt one moment Hitler lauded the Social Democratic Party in his mind (Friedrich Austerlitz, Anton David, Viktor Adler and Wilhelm Ellenbogen are mentioned by their last names on p. 66 of Mein Kampf) for championing ""das allgemeine und geheime Wahlrecht"" (Mein Kampf, p. 39) or universal suffrage, saying to himself that this must lead to a weakening of the so-hated-by-him Habsburg monarchic rule. When he later saw more of what the party was about however, he suggested its members were being servile to the Slavs to the point of debasing themselves, as one who is utterly mendicant, whilst disfavoring the Germans in an attempt to save the multiethnic melting pot of Austria-Hungary. It appears he was angered that he was asked to join the trade union and that men of Jewish heritage were influential in the party (Mein Kampf, pgs. 39, 40, 65 and 66).Mein Kampf, Erster Band, 2. Kapitel "',
b"Robert Young was not a creationist. Don't troll my page.",
b'Adding the Nazi party to Infobox politician under the heading Minister-President of Baden-W\xc3\xbcrttemberg, which makes it look like he represented the Nazi party as Minister-President, may easily be interpreted as vandalism. The Nazi party does not belong to the infobox for very obvious reasons stated at the article talk page.',
b'18 October 2010 (UTC)',
b"Just post the damn meme \n\nLook, Wikipedia is a bunch of pompous bullshit that's good for trolling and not much else, so let's throw the arbitrary standards out the window. Srsly, you come down on vandalism (real or imagined) like a ton of bricks within a minute, but how long did it take you to notice the wikipedos trying to use wikipedia to advocate raping kids? Kind of defeats the purpose of a hive mind if it can't stop shit like that.",
b'"\n We don\'t just use ""website content"", many players who played pre-1990s have many paper sources. It\'s just natural than modern footballers and modern international foreign players are sourced by internet links. If there\'s a particular issue with the reliability of sources on these articles then that\'s a different discussion altogether. "',
b'Please stop. If you continue to blank out or delete portions of page content, templates or other materials from Wikipedia, you will be blocked from editing. (t|p|c)',
b'Northeastern State University \n\nHello, Ya im a student at NSU, When I started Wikipedia NSU was the first article that I began to imporve and have had no help with it so anything you would like to do would be great. If you think the photos need to increase then go for it, but I would not make them two big so that it doesnt slow down computers with a slower internet connection or look bad on small screens. I redid the info box and modeled it after the OU page. I to have wanted to add info about BA and Muskogee for a while, but I havent found much or got around to it. If you wanna add to these parts of the article ill help you out if you need it. I also will try to add some things myself to make it better.',
b"Mr Creakle \n\nI personally think that this edit shows Mr. Creakle is too favourable a light. More so than he should be. He is, after all, one the bane's of David's school days. Thoughts?",
b'i am not vandalizing your talk page you stupid bitch\n\nyou must have some other enemies, now do yourself a favor and fuck off',
b'"Also what you claim as FACTS are not FACTS, do you even read the stuff you post?\n\nThis is a subject that is still controversial, both as to the area of origin and the routes\nbetween that area and the Philippines. There are two opposing hypotheses on both\nsubjects. Both hypotheses are agreed that the differing major languages and the many\ndifferent dialects spoken in the Philippines at the time of Portuguese and Spanish entry to\nthe Philippines in the early 15th century were all of the Austronesian Super Family. That\nis where the agreement ends.\nIn the Philippines, the two best known hypotheses are those of Peter Bellwood and\nmine (Solheim). While my hypothesis\xe2\x80\x94which will be the main subject of this\npresentation\xe2\x80\x94started developing over 40 years ago, Bellwood\xe2\x80\x99s is more recently\ndeveloped and has had much greater acceptance outside of the Philippines. Simply put,\nhe believes that ProtoAustronesian\noriginated in eastern South China opposite Taiwan\nand was brought from there to Taiwan by maritime Chinese bringing with them rice\nagriculture; this around 7000 years ago.\n\nYour claims on facts are based on what? Obviously everything you state and post are just your opinions. You can\'t even show a source on how Filipinos having Malay background is false, you just simply state what you ""know"".\n\n"',
b'"\n Do you have a reference that I can look at? worm "',
b'I think you are forgetting the simple fact that Frau Rathlef and Gleb Botkin were both first hand witnesses to the events that took place. They ARE the sources. 72.25.99.30',
b"I honestly believe you're that same exact",
b'I am going to kill you I am going to murder you',
b"In addition \nMain article: Wikipedia:Biographies of living persons\n'When writing about controversies in the lead of the biography of a living person, notable material should neither be suppressed nor allowed to overwhelm' The blockquotes *do* overwhelm the article.",
b"Further Developments\nIt should be included that the girl('s body) was moved by ground ambulance with the assistance of the Terry Schiavo Foundation, first to an undisclosed Catholic Charity Medical Center in the Greater San Francisco Bay Area for doctors there to re-install feeding and hydration tubes, and then subsequently transferred to a still-further facility whose identity also remains a mystery where she is reported to be ``improving'' after suffering nearly a month with neither feeding nor any hydration other than those contained in the IV fluids.",
b"Manufacturer/AKA \n\nPrinceGloria, please go read the Infobox conventions. The Vibe's Manufacturer is Pontiac, not the NUMMI plant. Also, no one ever agreed that the Matrix/Vibe were related. Please read this discussion again. You are the only person who contended that the Matrix and Vibe were related and not AKA.",
b'"\nSylwiaS wrote ""I think .. "" moje gratulacje WOW , COOL ..etc"',
b"\xe2\x80\xa6 \n\nauthored: past participle of author\n\nwritten: past participle of write\n\nwrite: be the author of\n\nSince they mean the same thing, I don't care that it's presently written. \xc2\xa6",
b'Cultural Signicance \n\nThe article totally lacks any information on the cultural and literary significance of the plant. As such it is incomplete and for my current purposes useless.',
b'John Lavelle, saving the world via Wikipedia!',
b'"\n I can provide releant diff at opportune and appropriate time. Meanwhile, I don\'t agree with your assessment as to which books are ""most reliable"". if you\'re not acquainted with Ralph Bennett, then we\'re not on the same page. I\'d recommmend his work Ultra and Mediterraen Strategy, and also his article ""Ultra and some command decisions"", Journal of Contemporary History, Vol 16, 1981. \n\n As for intel provided to Soviets: I\'m fairly certain Churchill\'s selective provision of Ulta intel to the Soviets was strategically self-serving. Also, it\'s one thing to provide someone with intel that YOU surmise he needs, and another thing entirely to provide intel that he knows he needs. Far as article is concerned, the point is that 10 years after the article\'s first appearance, reliable Hinsley and Bennett remain conspicuously absent from references, (as is conspicuously absent from the article content the crucial matter of strategy per se); whereas there\'s a preponderence of minutae about mostly side-show issues, and what one editor has aptly described as ""crappy POV pushing references."" In other words, sloppy and partisan editing, whatever the reason or reasons for it. "',
b'"There is no other references on the internet to ""Lepore syndrome"". The correct medical term is ""Hemoglobin Lepore Syndrome"" which can be abbreviated to Hb Lepore Syndrome ( Hb is an abbreviation of hemoglobin). "',
b"Alright Francophonie, a la prochaine. Hey, it's late here: care to hold the fort? I know you're going to invite your young and lively friends in, but please don't let them take the silver or break the china.",
b"'' as used in U.S. discourse on society and ethnicity"],
dtype=object), array([[1, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 1, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]))
We also check the shapes. We expect features of shape (batch,) and targets of shape (batch, number of labels).
text train shape: (32,)
text train type: object
label train shape: (32, 6)
label train type: int64
Of course preprocessing! Text is not the type of input a NN can handle. The TextVectorization layer is meant to handle natural language inputs. The processing of each example contains the following steps:
1. Standardize each example (usually lowercasing + punctuation stripping)
2. Split each example into substrings (usually words)
3. Recombine substrings into tokens (usually ngrams)
4. Index tokens (associate a unique int value with each token)
5. Transform each example using this index, either into a vector of ints or a dense float vector.
For more reference, see the documentation at the following link.
text_vectorization = TextVectorization(
max_tokens=config.max_tokens,
standardize="lower_and_strip_punctuation",
split="whitespace",
output_mode="int",
output_sequence_length=config.output_sequence_length,
pad_to_max_tokens=True
)
# prepare a dataset that only yields raw text inputs (no labels)
text_train_ds = train_ds.map(lambda x, y: x)
# adapt the text vectorization layer to the text data to index the dataset vocabulary
text_vectorization.adapt(text_train_ds)
This layer is set to:
- max_tokens: 20000, a common choice for text classification. It is the maximum size of the vocabulary for this layer.
- output_sequence_length: 911. See Figure 3 for the reason why. Only valid in "int" mode.
- output_mode: "int", which outputs one integer index per split string token. When output_mode == "int", 0 is reserved for masked locations; this reduces the vocab size to max_tokens - 2 instead of max_tokens - 1.
- standardize: "lower_and_strip_punctuation".
- split: on whitespace.
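For intuition, the standardize/split/index steps can be mimicked in plain Python. This is a toy sketch with a made-up vocabulary, not the real layer, which builds its vocabulary with adapt() and runs in TensorFlow:

```python
import re

def standardize(text):
    # step 1: lowercase + strip punctuation
    return re.sub(r"[^\w\s]", "", text.lower())

def split(text):
    # step 2: split on whitespace
    return text.split()

# steps 4-5: index tokens; 0 is reserved for padding/masking, 1 for OOV
vocab = {"": 0, "[UNK]": 1, "you": 2, "are": 3, "a": 4, "hero": 5}

def vectorize(text, output_sequence_length=8):
    ids = [vocab.get(tok, 1) for tok in split(standardize(text))]
    # pad/truncate to a fixed length, as the layer does in "int" mode
    return (ids + [0] * output_sequence_length)[:output_sequence_length]

print(vectorize("You, sir, are a HERO!"))  # [2, 1, 3, 4, 5, 0, 0, 0]
```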
To preserve the original comments as text and also have a tf.data.Dataset in which the text is preprocessed by the TextVectorization function, it is possible to map it to the features of each dataset.
processed_train_ds = train_ds.map(
lambda x, y: (text_vectorization(x), y),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_val_ds = val_ds.map(
lambda x, y: (text_vectorization(x), y),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_test_ds = test_ds.map(
lambda x, y: (text_vectorization(x), y),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
Define the model using the Functional API.
def get_deeper_lstm_model():
    clear_session()
    inputs = Input(shape=(None,), dtype=tf.int64, name="inputs")
    embedding = Embedding(
        input_dim=config.max_tokens,
        output_dim=config.embedding_dim,
        mask_zero=True,
        name="embedding"
    )(inputs)
    x = Bidirectional(LSTM(256, return_sequences=True, name="bilstm_1"))(embedding)
    x = Bidirectional(LSTM(128, return_sequences=True, name="bilstm_2"))(x)
    # global average pooling
    x = GlobalAveragePooling1D()(x)
    # add regularization
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
    x = LayerNormalization()(x)
    outputs = Dense(len(config.labels), activation='sigmoid', name="outputs")(x)
    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss="binary_crossentropy", metrics=config.metrics, steps_per_execution=32)
    return model

lstm_model = get_deeper_lstm_model()
lstm_model.summary()
Finally, the model has been trained using 2 callbacks:
- Early Stopping, to avoid consuming the Kaggle GPU time.
- Model Checkpoint, to retrieve the best model training information.
Since the dataset is imbalanced, to improve performance we need to calculate the class weights. These will be passed during the training of the model.
class_weight
toxic 0.095900590
severe_toxic 0.009928468
obscene 0.052757858
threat 0.003061800
insult 0.049132042
identity_hate 0.008710911
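The class_weight values above look close to each label's share of the full dataset. A hedged reconstruction (the exact computation is not shown in the post; the counts are taken from the grouped frequency table earlier, and the small residual differences suggest the original figures were computed on slightly different data):

```python
# Positive counts per label, from the grouped table earlier in the post.
counts = {
    "toxic": 15294,
    "severe_toxic": 1595,
    "obscene": 8449,
    "threat": 478,
    "insult": 7877,
    "identity_hate": 1405,
}
total = 159571  # total number of comments

# Weight = fraction of comments carrying each label.
class_weight = {label: n / total for label, n in counts.items()}
print(round(class_weight["toxic"], 4))   # close to the 0.0959 reported above
print(round(class_weight["threat"], 4))  # close to the 0.0031 reported above
```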
It is also useful to define the steps per epoch for the train and validation datasets. This is required to avoid exhausting the dataset during the fit, which happened to me.
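The step counts can be derived from the sample counts in Config. A minimal sketch; the post does not show the exact formula, and floor division is one common convention when the dataset is repeated:

```python
# Sample counts from Config.
train_samples = 111699
val_samples = 23936
batch_size = 32

# With .repeat(), Keras cannot infer an epoch length,
# so steps must be passed explicitly to fit().
steps_per_epoch = train_samples // batch_size  # 3490 full batches per epoch
validation_steps = val_samples // batch_size   # 748
print(steps_per_epoch, validation_steps)
```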
The fit has been done on Kaggle to leverage the GPU. Some considerations about the fit:
- .repeat() ensures the model sees the whole dataset.
- epochs is set to 100.
- validation_data has the same repeat.
- callbacks are the ones defined before.
- class_weight ensures the model is trained using the frequency of each class, because our dataset is imbalanced.
- steps_per_epoch and validation_steps depend on the use of repeat.

Now we can import the model and the history trained on Kaggle.
# A tibble: 5 × 2
metric value
<chr> <dbl>
1 loss 0.0542
2 precision 0.789
3 recall 0.671
4 auc 0.957
5 f1_score 0.0293
For the prediction, the model does not need to repeat the dataset, because it has already been trained on all of the train data. Now it just has to consume the new data to make the prediction.
The best way to assess the performance of a multi-label classification is a confusion matrix. Sklearn has a specific function, multilabel_confusion_matrix, to handle the fact that there can be multiple labels for one prediction.
Grid Search CV is a technique for fine-tuning the hyperparameters of a ML model. It systematically searches through a set of hyperparameter values to find the combination which leads to the best model performance. In this case, I am using KFold Cross Validation, a resampling technique which splits the data into k consecutive folds. Each fold is used once as validation while the remaining k - 1 folds form the training set. See the documentation for more information.
The model is tuned to optimize the recall. This decision was made because the cost of missing a True Positive is greater than that of a False Positive: in this case, missing an injurious observation is worse than classifying a clean one as bad.
Whilst the KFold Grid Search CV technique is useful to test multiple hyperparameters, it is important to understand the problem we are facing. A multi-label deep learning classifier outputs a vector of per-class probabilities, which need to be converted to a binary vector using a confidence threshold.
Threshold selection means we have to decide which metric to prioritize, based on the problem we are facing and the relative cost of misjudging. We can consider toxic comment filtering a problem similar to cancer diagnosis: it is better to predict cancer in people who do not have it [False Positive] and perform further analysis than to miss it in a patient who has the disease [False Negative].
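A toy illustration of this trade-off, with made-up probabilities: lowering the decision threshold raises recall at the cost of precision.

```python
import numpy as np

# Fabricated ground truth and predicted probabilities for one label.
ytrue = np.array([1, 1, 1, 0, 0, 0])
yproba = np.array([0.9, 0.4, 0.1, 0.35, 0.06, 0.02])

def recall_at(threshold):
    ypred = (yproba >= threshold).astype(int)
    tp = np.sum((ypred == 1) & (ytrue == 1))
    fn = np.sum((ypred == 0) & (ytrue == 1))
    return tp / (tp + fn)

print(recall_at(0.5))   # 1/3: two injurious comments slip through
print(recall_at(0.05))  # 1.0: all caught, at the cost of more false positives
```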
I decided to train the model on the F1 score, to have a model balanced in both precision and recall, and to leave it to the threshold selection to increase the recall performance.
Moreover, the model has been trained on the macro average F1 score, a single performance indicator obtained as the mean of the F1 scores of the individual classes.
It is useful for imbalanced classes, because it weights each class equally: it is not influenced by the number of samples per class. This is set both in config.metrics and in find_optimal_threshold_cv.
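A small sklearn example of the difference between macro and micro averaging (the arrays are made up):

```python
import numpy as np
from sklearn.metrics import f1_score

# Tiny multi-label example: 4 samples, 2 labels.
ytrue = np.array([[1, 0], [1, 0], [0, 1], [0, 0]])
ypred = np.array([[1, 0], [0, 0], [0, 1], [0, 0]])

# macro: average of per-class F1s, each class weighted equally
# micro: F1 computed on counts pooled over all labels
macro = f1_score(ytrue, ypred, average="macro")
micro = f1_score(ytrue, ypred, average="micro")
print(macro, micro)  # 0.8333..., 0.8
```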
Optimal threshold: 0.15000000000000002
Best score: 0.4788653077945807
Optimal threshold f1 score: 0.15. Best score: 0.4788653.
Optimal threshold recall: 0.05. Best score: 0.8095814.
Optimal threshold: 0.05
Best score: 0.8809499649742268
Optimal threshold roc: 0.05. Best score: 0.88095.
# convert probability predictions to predictions
ypred = predictions >= optimal_threshold_recall # .05
ypred = ypred.astype(int)
# create a plot with 3 by 2 subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))
axes = axes.flatten()
mcm = multilabel_confusion_matrix(ytrue, ypred)
# plot the confusion matrices for each label
for i, (cm, label) in enumerate(zip(mcm, config.labels)):
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(ax=axes[i], colorbar=False)
axes[i].set_title(f"Confusion matrix for label: {label}")
plt.tight_layout()
plt.show()
# A tibble: 10 × 5
metrics precision recall `f1-score` support
<chr> <dbl> <dbl> <dbl> <dbl>
1 toxic 0.552 0.890 0.682 2262
2 severe_toxic 0.236 0.917 0.375 240
3 obscene 0.550 0.936 0.692 1263
4 threat 0.0366 0.493 0.0681 69
5 insult 0.471 0.915 0.622 1170
6 identity_hate 0.116 0.720 0.200 207
7 micro avg 0.416 0.896 0.569 5211
8 macro avg 0.327 0.812 0.440 5211
9 weighted avg 0.495 0.896 0.629 5211
10 samples avg 0.0502 0.0848 0.0597 5211
The BiLSTM model, optimized for high recall, performs well enough to make predictions for each label. Considering the low support for the threat label, the performance is not bad: see Table 2 and Figure 1, the threat label is only 0.27% of the observations. The model has been optimized for recall because the cost of not identifying an injurious comment as such is higher than the cost of flagging a clean comment as injurious.
Possible improvements could be to increase the number of observations, especially for the threat label. In general there are too many clean comments. This could be addressed by undersampling the clean comments, which I explicitly avoided in order to check the performance of the BiLSTM on an imbalanced dataset, leveraging the class weight method.
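For reference, the undersampling mentioned above could be sketched like this. The tiny DataFrame is synthetic and the single toxic column is a simplification; this is not the approach used in the post:

```python
import numpy as np
import pandas as pd

# Synthetic imbalanced data: 20 toxic, 80 clean comments.
df = pd.DataFrame({
    "comment_text": [f"comment {i}" for i in range(100)],
    "toxic": [1] * 20 + [0] * 80,
})

toxic = df[df["toxic"] == 1]
clean = df[df["toxic"] == 0]

# Keep every toxic comment, sample as many clean ones as there are toxic.
clean_down = clean.sample(n=len(toxic), random_state=42)
balanced = pd.concat([toxic, clean_down]).sample(frac=1, random_state=42)  # shuffle

print(balanced["toxic"].value_counts().to_dict())  # {1: 20, 0: 20}
```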